import os

import torch
import torch.nn as nn
from PIL import Image

class CompressionAttack:
    """
    Base class for attacks against the defense stage of a model.

    Only ``model[0]`` (the defense / compression stage) is kept and attacked.
    Subclasses implement :meth:`apply_attack`; calling an instance forwards all
    arguments to it.
    """

    def __init__(self, model, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
        """
        Initialize the CompressionAttack class.

        :param model: Indexable model (e.g. ``nn.Sequential``) whose first element,
            ``model[0]``, is the defense to attack.
        :param device: Device the defense is moved to and attack tensors live on.
        :raises ValueError: If ``model[0]`` cannot be obtained (empty container or
            a non-subscriptable model).
        """
        try:
            self.model = model[0]
        except (IndexError, KeyError, TypeError) as err:
            # IndexError: empty container; TypeError: model is not subscriptable.
            raise ValueError("Model must be a sequential model with model[0] being the defense.") from err
        self.device = device
        self.model.to(self.device)

    def apply_attack(self, *args, **kwargs):
        """
        Apply the compression attack on the model.

        Subclasses must override this; ``__call__`` forwards its arguments here.

        :raises NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement apply_attack()"
        )

    def __call__(self, *args, **kwds):
        return self.apply_attack(*args, **kwds)

    def _save_images(self, images, path, amount=1):
        """
        Save the first ``amount`` images of a batch as PNG files.

        Files are named ``image_<eps*255>_<index>.png``. The eps tag is 0 when
        the attack has no ``eps`` attribute.

        :param images: Batch of images; assumed (N, C, H, W) with values in
            [0, 1], as the transpose below implies — TODO confirm for callers.
        :param path: Directory to write into; created if it does not exist.
        :param amount: Number of leading images to save.
        """
        os.makedirs(path, exist_ok=True)
        eps_tag = f"{getattr(self, 'eps', 0) * 255:.0f}"
        for i, image in enumerate(images[:amount]):
            # Clamp first: uint8 conversion would otherwise wrap out-of-range pixels.
            array = image.detach().cpu().clamp(0, 1).numpy()
            array = array.transpose(1, 2, 0)
            if array.shape[-1] == 1:
                # PIL cannot build an image from an (H, W, 1) array.
                array = array[..., 0]
            array = (array * 255).round().astype('uint8')
            Image.fromarray(array).save(os.path.join(path, f"image_{eps_tag}_{i}.png"))

    
class SimpleCompressionAttack(CompressionAttack):
    """
    Untargeted L-inf attack that maximizes the defense's reconstruction error.

    A perturbation ``delta`` with ``|delta| <= eps`` is optimized with Adam so
    that the MSE between ``model[0](clamp(images + delta, 0, 1))`` and the clean
    ``images`` is as large as possible.
    """

    def __init__(self, model, steps=20, eps=8/255, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
        """
        Initialize the SimpleCompressionAttack class.

        :param model: Indexable model; ``model[0]`` is the defense to attack.
        :param steps: Number of optimization steps.
        :param eps: L-inf bound on the perturbation.
        :param device: Device for the defense and all attack tensors.
        """
        super().__init__(model, device)
        self.steps = steps
        self.eps = eps

    def apply_attack(self, images, labels=None):
        """
        Apply the compression attack on the model.

        :param images: Batch of clean images; presumably in [0, 1] given the
            final clamp — TODO confirm.
        :param labels: Unused; accepted for a uniform attack interface.
        :return: Detached adversarial images ``clamp(images + delta, 0, 1)`` on
            ``self.device``.
        """
        original_images = images.detach().clone().to(self.device)
        adversarial_noise = torch.empty_like(images).uniform_(
            -self.eps, self.eps
        ).to(self.device)
        # The noise must be a grad-requiring leaf from the very first step;
        # otherwise step 0 receives no gradient.
        adversarial_noise.requires_grad_(True)
        criterion = nn.MSELoss()
        optimizer = torch.optim.Adam([adversarial_noise], lr=10.0)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

        for _ in range(self.steps):
            # Optimize on the same clamped input that is returned at the end.
            outputs = self.model(torch.clamp(original_images + adversarial_noise, 0, 1))
            # Negated because Adam minimizes and we want to maximize the error.
            loss = -criterion(outputs, original_images)
            # autograd.grad leaves the defense's parameter .grad untouched.
            (gradient,) = torch.autograd.grad(loss, adversarial_noise)
            adversarial_noise.grad = gradient
            optimizer.step()
            with torch.no_grad():
                # Project in place: the optimizer keeps tracking the same
                # tensor, so Adam's moment estimates survive across steps.
                adversarial_noise.clamp_(-self.eps, self.eps)
            scheduler.step()

        final_noise = adversarial_noise.detach()
        assert torch.all(torch.abs(final_noise) <= self.eps*1.01), "Adversarial noise exceeds epsilon constraint"
        return torch.clamp(original_images + final_noise, 0, 1)

class DiffCompressionAttack(CompressionAttack):
    """
    L-inf attack that pushes the defense output away from its clean output.

    A perturbation ``delta`` with ``|delta| <= eps`` is optimized with Adam to
    maximize the MSE between ``model[0](clamp(images + delta, 0, 1))`` and
    ``model[0](images)``.
    """

    def __init__(self, model, steps=20, eps=8/255,
                 device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
                 verbose=True):
        """
        Initialize the DiffCompressionAttack class.

        :param model: Indexable model; ``model[0]`` is the defense to attack.
        :param steps: Number of optimization steps.
        :param eps: L-inf bound on the perturbation.
        :param device: Device for the defense and all attack tensors.
        :param verbose: If True, print distortion diagnostics and save the first
            defended adversarial image under ``debug_images/``.
        """
        super().__init__(model, device)
        self.steps = steps
        self.eps = eps
        self.verbose = verbose

    def apply_attack(self, images, labels=None):
        """
        Apply the compression attack on the model.

        :param images: Batch of clean images; presumably in [0, 1] given the
            final clamp — TODO confirm.
        :param labels: Unused; accepted for a uniform attack interface.
        :return: Detached adversarial images ``clamp(images + delta, 0, 1)`` on
            ``self.device``.
        """
        original_images = images.detach().clone().to(self.device)
        adversarial_noise = torch.empty_like(images).uniform_(
            -self.eps, self.eps
        ).to(self.device)
        # Grad-requiring leaf from the start so step 0 is not wasted.
        adversarial_noise.requires_grad_(True)
        criterion = nn.MSELoss()
        optimizer = torch.optim.Adam([adversarial_noise], lr=10.0)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
        with torch.no_grad():
            # Fixed target: the defense's output on the clean images.
            transformed_images = self.model(original_images)

        for _ in range(self.steps):
            outputs = self.model(torch.clamp(original_images + adversarial_noise, 0, 1))
            loss = -criterion(outputs, transformed_images)
            # autograd.grad leaves the defense's parameter .grad untouched.
            (gradient,) = torch.autograd.grad(loss, adversarial_noise)
            adversarial_noise.grad = gradient
            optimizer.step()
            with torch.no_grad():
                # In-place projection keeps Adam's state attached to this tensor.
                adversarial_noise.clamp_(-self.eps, self.eps)
            scheduler.step()

        final_noise = adversarial_noise.detach()
        assert torch.all(torch.abs(final_noise) <= self.eps*1.01), "Adversarial noise exceeds epsilon constraint"
        adv_images = torch.clamp(original_images + final_noise, 0, 1)

        if self.verbose:
            with torch.no_grad():
                print(f"Max epsilon: {torch.max(torch.abs(final_noise)).item()}")
                outputs = torch.clamp(self.model(adv_images), 0, 1)
                difference = outputs - original_images
                print(f"L1 distance: {torch.linalg.vector_norm(difference, ord=1).item()}")
                print(f"L2 distance: {torch.linalg.vector_norm(difference, ord=2).item()}")
                print(f"Linf distance: {torch.linalg.vector_norm(difference, ord=float('inf')).item()}")
            # Saving into a missing directory would otherwise fail.
            os.makedirs("debug_images", exist_ok=True)
            self._save_images(outputs, "debug_images")

        return adv_images
        
class BitrateCompressionAttack(CompressionAttack):
    """
    L-inf attack on a defense that exposes ``get_bitrate``.

    The minimized objective is the negated L1 reconstruction error plus a small
    weighted sum of the two bitrate terms returned by ``model[0].get_bitrate``.
    """

    def __init__(self, model, steps=20, eps=8/255, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
        """
        Initialize the BitrateCompressionAttack class.

        :param model: Indexable model; ``model[0]`` is the defense and must
            provide ``get_bitrate(x) -> (y_bitrate, z_bitrate, outputs)``.
        :param steps: Number of optimization steps.
        :param eps: L-inf bound on the perturbation.
        :param device: Device for the defense and all attack tensors.
        """
        super().__init__(model, device)
        self.steps = steps
        self.eps = eps

    def apply_attack(self, images, labels=None):
        """
        Apply the compression attack on the model.

        :param images: Batch of clean images; presumably in [0, 1] given the
            final clamp — TODO confirm.
        :param labels: Unused; accepted for a uniform attack interface.
        :return: Detached adversarial images ``clamp(images + delta, 0, 1)`` on
            ``self.device``.
        """
        original_images = images.detach().clone().to(self.device)
        adversarial_noise = torch.empty_like(images).uniform_(
            -self.eps, self.eps
        ).to(self.device)
        # Grad-requiring leaf from the start so step 0 is not wasted.
        adversarial_noise.requires_grad_(True)
        criterion = nn.L1Loss()
        optimizer = torch.optim.Adam([adversarial_noise], lr=1.0)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

        for _ in range(self.steps):
            y_bitrate, z_bitrate, outputs = self.model.get_bitrate(
                torch.clamp(original_images + adversarial_noise, 0, 1)
            )
            # y_bitrate / z_bitrate are presumably latent and hyper-latent rate
            # estimates — confirm against the defense implementation.
            loss = -criterion(outputs, original_images) + 0.01 * (
                y_bitrate.sum() * 0.0001 + z_bitrate.sum() * 0.001
            )
            # autograd.grad leaves the defense's parameter .grad untouched.
            (gradient,) = torch.autograd.grad(loss, adversarial_noise)
            adversarial_noise.grad = gradient
            optimizer.step()
            with torch.no_grad():
                # In-place projection keeps Adam's state attached to this tensor.
                adversarial_noise.clamp_(-self.eps, self.eps)
            scheduler.step()

        final_noise = adversarial_noise.detach()
        assert torch.all(torch.abs(final_noise) <= self.eps*1.01), "Adversarial noise exceeds epsilon constraint"
        return torch.clamp(original_images + final_noise, 0, 1)
        


class OutsideCompressionAttack(CompressionAttack):
    """
    Two-phase L-inf attack that starts from a loose bound and anneals it.

    Phase 1 runs ``warmup_steps`` Adam steps under the loose bound
    ``outside_eps``. Each step maximizes the L1 error between the defense
    output and the clean images. Phase 2 runs ``steps`` more steps while the
    bound shrinks linearly from ``outside_eps`` to ``eps``. Finally the
    perturbation is projected to ``[-eps, eps]``.
    """

    def __init__(self, model, steps=20, eps=8/255, outside_eps=64/255,
                 device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
                 warmup_steps=5, verbose=True):
        """
        Initialize the OutsideCompressionAttack class.

        :param model: Indexable model; ``model[0]`` is the defense to attack.
        :param steps: Number of annealing steps in phase 2.
        :param eps: Final L-inf bound on the perturbation.
        :param outside_eps: Initial, looser L-inf bound used in phase 1.
        :param device: Device for the defense and all attack tensors.
        :param warmup_steps: Number of phase-1 steps under ``outside_eps``.
        :param verbose: If True, print diagnostics and save the first defended
            adversarial image under ``debug_images/``.
        """
        super().__init__(model, device)
        self.steps = steps
        self.eps = eps
        self.outside_eps = outside_eps
        self.warmup_steps = warmup_steps
        self.verbose = verbose
        # Per-step shrink of the bound in phase 2 (0 when there are no steps,
        # instead of a ZeroDivisionError).
        self.eps_step_size = (self.outside_eps - self.eps) / self.steps if self.steps else 0.0

    def _report(self, original_images, adversarial_noise):
        """Print max |noise| and output distortion norms; return the clamped defense output."""
        with torch.no_grad():
            print(f"Max epsilon: {torch.max(torch.abs(adversarial_noise)).item()}")
            outputs = torch.clamp(
                self.model(torch.clamp(original_images + adversarial_noise, 0, 1)), 0, 1
            )
            difference = outputs - original_images
            print(f"L1 distance: {torch.linalg.vector_norm(difference, ord=1).item()}")
            print(f"L2 distance: {torch.linalg.vector_norm(difference, ord=2).item()}")
            print(f"Linf distance: {torch.linalg.vector_norm(difference, ord=float('inf')).item()}")
        return outputs

    def apply_attack(self, images, labels=None):
        """
        Apply the compression attack on the model.

        Does not modify ``self.outside_eps``, so repeated calls behave the same.

        :param images: Batch of clean images; presumably in [0, 1] given the
            final clamp — TODO confirm.
        :param labels: Unused; accepted for a uniform attack interface.
        :return: Detached adversarial images ``clamp(images + delta, 0, 1)`` on
            ``self.device``, with ``|delta| <= eps``.
        """
        original_images = images.detach().clone().to(self.device)
        adversarial_noise = torch.empty_like(images).uniform_(
            -self.outside_eps, self.outside_eps
        ).to(self.device)
        # Grad-requiring leaf from the start so the first step is not wasted.
        adversarial_noise.requires_grad_(True)
        criterion = nn.L1Loss()
        optimizer = torch.optim.Adam([adversarial_noise], lr=10.0)

        def ascent_step(bound):
            """Take one Adam step maximizing the L1 error, project to [-bound, bound], and return the loss."""
            outputs = self.model(torch.clamp(original_images + adversarial_noise, 0, 1))
            loss = -criterion(outputs, original_images)
            # autograd.grad leaves the defense's parameter .grad untouched.
            (gradient,) = torch.autograd.grad(loss, adversarial_noise)
            adversarial_noise.grad = gradient
            optimizer.step()
            with torch.no_grad():
                # In-place projection keeps Adam's state attached to this tensor.
                adversarial_noise.clamp_(-bound, bound)
            return loss

        # Phase 1: warm-up under the loose bound.
        for _ in range(self.warmup_steps):
            ascent_step(self.outside_eps)

        if self.verbose:
            print("Outside steps:")
        assert torch.all(torch.abs(adversarial_noise) <= self.outside_eps*1.01), "Adversarial noise exceeds epsilon constraint"
        if self.verbose:
            self._report(original_images, adversarial_noise)

        # Phase 2: shrink the bound linearly towards eps. A local variable is
        # used so self.outside_eps is not consumed across calls.
        bound = self.outside_eps
        for step in range(self.steps):
            bound -= self.eps_step_size
            loss = ascent_step(bound)
            if self.verbose:
                print(f"Step {step+1}/{self.steps}, Loss: {loss.item()}")

        final_noise = torch.clamp(adversarial_noise.detach(), -self.eps, self.eps)
        assert torch.all(torch.abs(final_noise) <= self.eps*1.01), "Adversarial noise exceeds epsilon constraint"

        if self.verbose:
            print(f"Outside epsilon: {bound}")
            print("Standard Epsilon attack:")
            outputs = self._report(original_images, final_noise)
            # Saving into a missing directory would otherwise fail.
            os.makedirs("debug_images", exist_ok=True)
            self._save_images(outputs, "debug_images")

        return torch.clamp(original_images + final_noise, 0, 1)

        

